import os, sys
from BasicLibraries import *
import Functions.SDG_Matrix.SDGMatrixFunctions as sdgFun
from statsmodels.discrete.discrete_model import Probit, Logit
import Functions.Miscellaneous.utils as ut
from scipy import stats
import colorsys
from colour import Color
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import ListedColormap
import matplotlib.cm as cm
from pysankey import sankey


# Placeholder path; it is an empty string here and is not read anywhere in
# the visible part of this module.
ONEDRIVE = ''
# Output folder for saved figures: '<current working directory>/Figures/'.
# The plotting methods below concatenate a file name onto this string.
FIGURES = os.getcwd()+'/Figures/'

# `plt` is not imported explicitly in this file; presumably it is provided by
# the star import from BasicLibraries -- TODO confirm.
plt.rcParams['xtick.major.pad']='8'
# Three colours returned by the project helper `ut.utils().get_LCColors()`.
LCGreen, LCYellow, LCBlue = ut.utils().get_LCColors()
from matplotlib import rcParams
# Global matplotlib font configuration applied at import time.
rcParams['font.family'] = 'serif'
# NOTE(review): this sets the *sans-serif* font list while the family above is
# 'serif'; possibly 'font.serif' was intended -- confirm.
rcParams['font.sans-serif'] = ['Verdana']
#====
# Maps every individual action label onto the broader mechanism it belongs to.
# It is used further down to relabel action columns/rows before aggregating.
nomenclature_dict = {
    **dict.fromkeys(
        (
            'r&d investments',
            'new products',
            'association',
            'organizational structuring',
        ),
        'Innovation capacity',
    ),
    **dict.fromkeys(
        (
            'training',
            'adoption of standards and rules',
            'assessment and measurement',
            'modification of procedures',
            'asset modification',
        ),
        'Risk mitigation',
    ),
}
#%%

class PlotMsc():
    """Plotting and summary-table helpers.

    The methods build regression tables, heatmaps with totals, Sankey
    diagrams and optimum-versus-observed deviation statistics. None of the
    methods visible here stores state on the instance.
    """
    def __init__(self):
        """Create the helper object; no attributes are initialised."""
        pass
    def StatisticsPerformanceCorrelationRegression(self, dx, dt, metrics):
        """Regress next-year performance on each metric plus a control set.

        For every entry of ``metrics`` an OLS model of
        ``next_year_performance`` on that metric and the controls (four firm
        characteristics plus ``rfyear`` and ``loc`` dummies) is fitted. The
        coefficients are scaled by std(regressor) / std(target), rounded to
        two decimals and suffixed with the marker returned by
        ``ut.utils().significance`` for the rounded p-value.

        Parameters
        ----------
        dx : DataFrame with ``gvkey``, ``rfyear``, ``loc``,
            ``next_year_performance``, the metric columns and the controls.
        dt : DataFrame providing ``mrg`` and ``MarketLeverage``; merged into
            ``dx`` with the default (shared-column) merge keys.
        metrics : iterable of metric column names. The result columns are
            relabelled with three fixed headers, so three metrics are expected.

        Returns
        -------
        tuple
            ``(table, model)``: the coefficient table and the fitted results
            object of the last metric processed.
        """
        frame = dx.copy().merge(dt[['mrg', 'MarketLeverage']])
        log_columns = ['Total_invested_capital_BS1']
        regressors = ['firm_size', 'Total_invested_capital_BS1', 'Tangibility', 'MarketLeverage']

        # Year and location indicators; the first level of each is dropped.
        for key in ('rfyear', 'loc'):
            indicators = pd.get_dummies(frame[key], drop_first=True)
            frame = pd.concat((frame, indicators), axis=1)
            regressors = regressors + list(indicators.columns)

        table = pd.DataFrame()
        for score_type in metrics:
            needed = ['gvkey', 'rfyear', 'next_year_performance', score_type] + regressors
            sample = frame[needed].dropna().sort_values(by=['gvkey', 'rfyear'])
            if score_type == 'score':
                # Only the 'score' metric is trimmed: keep rows whose absolute
                # value lies below the 98th percentile of the raw values.
                cutoff = sample[score_type].quantile(0.98)
                sample = sample[sample[score_type].abs() < cutoff]
            sample[log_columns] = sample[log_columns].apply(np.log)
            # log(0) gives -inf; treat infinities as missing and drop them.
            sample = sample.replace([-np.inf, np.inf], [np.nan, np.nan]).dropna()

            target = sample['next_year_performance']
            design = sample[[score_type] + regressors].rename(
                columns={score_type: 'distance measure'})
            lm = sm.OLS(target, design).fit(cov='H1')

            scaled = (lm.params * design.std() / target.std()).round(2)
            summary = pd.concat((scaled, lm.pvalues.round(2)), axis=1)
            summary.columns = ['coefficients', 'p-values']
            summary['coefficients'] = [
                str(coefficient) + ut.utils().significance(p_value)
                for coefficient, p_value in zip(summary['coefficients'].values,
                                                summary['p-values'].values)
            ]
            table = pd.concat((table, summary['coefficients']), axis=1)

        # NOTE(review): 'Humming' is presumably meant to be 'Hamming'; the
        # label is kept as-is because callers may select columns by it.
        table.columns = ['Humming distance', 'Cosine similarity',  'Score']
        return table, lm
    def MutationAnalysis(self, mutation, save_plot=False):
        """Plot the mean 'Ratio' per number of mutations with error bars.

        Only rows with ``Year`` greater than 2015 are used. The error bars are
        the per-group mean of the 'Error' column.

        Parameters
        ----------
        mutation : DataFrame with 'Year', 'Ratio', 'Error' and
            '# of mutations' columns.
        save_plot : when true, the figure is written to
            ``FIGURES + 'Fig_mutations.pdf'``.
        """
        group_key = '# of mutations'
        recent = mutation[mutation.Year > 2015]
        mean_ratio = recent[['Ratio', group_key]].groupby(group_key).mean()
        mean_error = recent[['Error', group_key]].groupby(group_key).mean()
        # The error frame must carry the same column name as the plotted one.
        mean_error = mean_error.rename(columns={'Error': 'Ratio'})

        mean_ratio.plot(yerr=mean_error, ls='-.', marker='o', c='b',
                        capsize=4, legend=False)
        plt.xlabel('Number of mutations')
        plt.ylabel('Relative effect of a change in \n relevant features')
        if not save_plot:
            return
        plt.savefig(FIGURES+'Fig_mutations.pdf', bbox_inches='tight', dpi=200)

    def PlotMatrix(self, Y,  ax = '', rounded = 1, dpi_=100, significance=False,plotTotal=True,plot_mat=True,
                         vmn = -150, vmx = 300, ctr=0, prec='.3g'):
        """Draw an annotated heatmap of ``Y`` extended with row/column totals.

        A 'Total' column (row sums) and a 'Total' row (column sums) are added,
        rows and columns are ordered by decreasing total with 'Total' kept
        last, and the result is drawn with seaborn on a new 14x14 figure.

        Parameters
        ----------
        Y : numeric DataFrame; it is copied, never modified. An existing
            'Total' row and/or column is discarded and recomputed.
        ax, significance, plotTotal, plot_mat : accepted for compatibility
            with existing callers but not used by this method.
        rounded : number of decimals kept in the returned/plotted values.
        dpi_ : resolution of the created figure.
        vmn, vmx, ctr : colour-scale minimum, maximum and centre.
        prec : format string for the cell annotations.

        Returns
        -------
        DataFrame
            The ordered, rounded matrix including the totals.
        """
        # Double transpose + copy, as before: gives an independent frame.
        X = Y.transpose().copy().transpose()
        # Drop stale totals wherever they exist. Requiring both a 'Total' row
        # and a 'Total' column raised KeyError (column only) or left duplicate
        # 'Total' rows that were also counted in the sums (row only).
        X = X.drop(index=['Total'], columns=['Total'], errors='ignore')

        grand_total = X.sum().sum()
        col_sum = np.round(X.sum(), 1)
        col_sum.name = 'Total'
        row_sum = np.round(X.sum(axis=1), 1)
        row_sum.name = 'Total'

        X['Total'] = row_sum
        # DataFrame.append was removed in pandas 2.0; concatenating the sums
        # as a one-row frame yields the same result on old and new versions.
        X = pd.concat([X, col_sum.to_frame().T])
        X.loc['Total', 'Total'] = np.round(grand_total, 0)

        # Order by decreasing total, then pin 'Total' to the last row/column.
        X = X.sort_values(by='Total', axis=0, ascending=False)
        X = X.sort_values(by='Total', axis=1, ascending=False)
        cols = list(X.drop(columns=['Total']).columns)
        rows = list(X.drop(index=['Total']).index)
        X = X.loc[rows + ['Total']]
        X = X[cols + ['Total']]
        X = X.round(rounded)

        plt.figure(figsize=(14, 14), dpi=dpi_)
        # Named `colormap` so the module-level `matplotlib.cm` import is not
        # shadowed inside this method.
        colormap = 'RdYlBu'  # 'PiYG'
        ax = sns.heatmap(X, annot=True, cbar=False, fmt=prec,
                         vmin=vmn, vmax=vmx,
                         center=ctr, linewidths=0.8,
                         annot_kws={"color": "k", 'size': '24'},
                         linecolor='black', alpha=0.6,
                         cmap=colormap)

        ax.yaxis.tick_right()
        ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=24)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=0, fontsize=24)
        # NOTE(review): the separator positions (10, 8, 11) are hard-coded and
        # do not follow the matrix shape -- confirm they match the inputs used.
        ax.hlines([10], *ax.get_xlim(), color='navy', linewidths=3)
        ax.vlines([8], color='navy', linewidths=3, ymin=0, ymax=11)
        plt.xlabel('')
        plt.ylabel('')
        return X
    def plot_sankey_mat(self, X, sdgs, actions):
        """Draw an action -> SDG Sankey diagram from per-pair counts.

        Parameters
        ----------
        X : DataFrame with one column per '<action> - SDG <sdg>' title.
        sdgs, actions : labels used to build those titles (SDG-major order).

        Returns
        -------
        DataFrame
            The expanded link table with 'Action', 'SDG' and 'weight'
            columns; each pair appears once per unit of its column sum.
        """
        titles = [a+' - SDG '+str(s) for s in sdgs for a in actions]
        z = X[titles].copy()

        #== Make sankey plot: one record per unit of weight, columns reversed
        records = []
        for title in z.columns[::-1]:
            pieces = title.split(' - ')
            action, sdg = pieces[0], pieces[1]
            repetition = z[title].sum()
            records.extend([action, sdg, repetition] for _ in range(int(repetition)))

        sk = pd.DataFrame(records, columns=['Action', 'SDG', 'weight'])
        sk['SDG'] = sk['SDG'].apply(lambda label: label.split('SDG ')[1])
        sk['SDG'] = sk['SDG'].replace(
            ['water & energy', 'biodiversity', 'cons. & prod.'],
            ['Clean water & energy', 'Biodiversity', 'Responsible Consumption \n & Production'])
        sk['Action'] = sk['Action'].apply(
            lambda name: name.capitalize().replace('R&d', 'R&D'))

        action_labels = list(sk.Action.unique())
        overall = action_labels + ['SDG '+str(s) for s in sdgs]
        palette = ['#EBEEEE', '#9D9D9D', '#002147', '#008AFF',
                   '#00BEFA', '#B5F0E7',
                   '#D24000',   '#c2b300',
                   '#92d600', '#08CDAE'][::-1]
        cols = palette[:len(action_labels)] + ['#08CDAE']*len(sk.SDG.unique())
        # `cols` is never longer than `overall`, so zip pairs every colour.
        col_dict = dict(zip(overall, cols))

        #==== Assign the SDG color to be the one of the most common action
        pair_weights = sk.groupby(['Action', 'SDG']).last().reset_index()
        ranked = pair_weights.sort_values(by=['SDG', 'weight'])
        max_mapping = ranked.groupby('SDG').last().reset_index()[['SDG', 'Action']]
        for sdg_label, top_action in zip(max_mapping['SDG'], max_mapping['Action']):
            col_dict[sdg_label] = col_dict[top_action]

        #====
        plt.figure(figsize=(40, 36))
        ax = plt.subplot()
        sankey(left=sk['Action'],
               right=sk['SDG'],
               aspect=200,
               fontsize=40, colorDict=col_dict, ax=ax)
        plt.tight_layout()
        return sk
    def plot_sankey_mat_incidence(self, X, sdgs, actions):
        """Draw a Sankey diagram from an (Action, SDG, Count) link table.

        ``X`` is copied and reduced to the 'Action', 'SDG' and 'Count'
        columns, sorted by 'Action' in descending order, and plotted with
        'Count' as the left-hand weight. ``sdgs`` and ``actions`` are accepted
        but not used. Nothing is returned.
        """
        links = pd.DataFrame(X.copy(), columns=['Action', 'SDG', 'Count'])
        links = links.sort_values(by='Action', ascending=False)
        links = links.reset_index(drop=True)

        plt.figure(figsize=(18, 18))
        #plt.figure(figsize =(30,34))
        ax = plt.subplot()
        sankey(links['Action'], links['SDG'], leftWeight=links['Count'],
               aspect=140, fontsize=44, ax=ax)
    def plot_optimal_sankey(self, Xmatrix, env_sdgs, actions):
        """Draw a Sankey diagram of the shares aggregated over many matrices.

        Each matrix in ``Xmatrix`` is unstacked into one wide row whose
        columns are titled '<ini> - <sdg>' (the unstacked frame must expose
        'ini' and 'sdg' columns). The rows are stacked, restricted to the
        titles built from ``actions`` and ``env_sdgs``, summed per title and
        normalised by the grand total; those shares weight both sides of the
        diagram. Nothing is returned.
        """
        stacked = pd.DataFrame()
        for position in range(len(Xmatrix)):
            long_form = Xmatrix[position].unstack().reset_index()
            long_form['titles'] = long_form['ini'] + ' - ' + long_form['sdg']
            as_row = long_form[[0, 'titles']].set_index('titles').transpose()
            # The newest row goes on top of the rows collected so far.
            stacked = pd.concat((as_row, stacked))

        titles = [a+' - SDG '+str(s) for s in env_sdgs for a in actions]
        z = stacked[titles].copy()

        #== Make sankey plot
        shares = (z.sum()/z.sum().sum()).reset_index()
        shares['Action'] = shares['titles'].apply(lambda title: title.split(' - ')[0])
        shares['SDG'] = shares['titles'].apply(lambda title: title.split(' - ')[1])
        shares = shares.sort_values(by='Action', ascending=False).reset_index(drop=True)
        weight = shares[0]

        plt.figure(figsize=(30, 34))
        ax = plt.subplot()
        sankey(shares['Action'], shares['SDG'], aspect=140,
               fontsize=44, ax=ax, rightWeight=weight, leftWeight=weight)



    def generate_dates(self, start_ = '2014-01-01', end_ = '2019-01-01'):
        """Return an all-NaN frame indexed by 100 evenly spaced timestamps.

        The index runs from ``start_`` to ``end_`` (both included) and is
        named ``0``; the single column is named ``''``.
        """
        stamps = pd.date_range(start=start_, end=end_, periods=100)
        frame = pd.DataFrame(stamps)
        frame[''] = np.full(len(frame), np.nan)
        return frame.set_index(0)
    def generate_colors(self):
        """Return 50 distinct RGB tuples drawn at random from 90 hues.

        The palette holds 90 evenly spaced hues at full saturation and
        brightness 0.7; 50 of them are picked without replacement using
        NumPy's global random state.
        """
        N = 90
        brightness = 0.7
        palette = [colorsys.hsv_to_rgb(step / N, 1, brightness) for step in range(N)]
        picks = np.random.choice(N, 50, replace=False)
        return [palette[pick] for pick in picks]
                    

    def get_deviation_matrix_binary(self, optima, observed):
        """Average the optimum, observed and deviation shares over matrices.

        Every ``optima[i]`` of shape (3, 9) and its ``observed[i]`` are
        divided by their own grand total so that they hold shares. Entries of
        any other shape are skipped.

        Parameters
        ----------
        optima, observed : equally indexed sequences of DataFrames whose rows
            are SDG labels and whose columns are action labels.

        Returns
        -------
        tuple
            ``(average_optima, average_obs, average_dev, osdg, oact, esdg,
            eact, act, chl)`` where the first three are the element-wise
            NaN-ignoring means in percent (transposed), ``osdg``/``esdg`` hold
            the per-matrix row shares and ``oact``/``eact`` the per-matrix
            column shares of the optimum/observed matrices, and ``act``/``chl``
            hold the per-matrix mean deviation by mechanism and by SDG.

        Raises
        ------
        ValueError
            If no entry of ``optima`` has shape (3, 9).
        """
        opt, obs, avd = [], [], []
        osdg, oact, esdg, eact = pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
        column_name = index_name = None
        for i in range(len(optima)):
            optimum = optima[i]
            if np.shape(optimum) == (3,9):
                actual = observed[i]
                total_opt = optimum.sum().sum()
                total_obs = actual.sum().sum()

                # Kept as incremental concatenations on purpose: the column
                # labels of these frames are part of what callers receive.
                osdg = pd.concat((osdg, optimum.sum(axis = 1)/total_opt), axis = 1)
                oact = pd.concat((oact, optimum.sum(axis = 0)/total_opt), axis = 1)
                esdg = pd.concat((esdg, actual.sum(axis = 1)/total_obs), axis = 1)
                eact = pd.concat((eact, actual.sum(axis = 0)/total_obs), axis = 1)

                optimum_share = optimum/total_opt
                observed_share = actual/total_obs
                opt.append(optimum_share)
                obs.append(observed_share)
                avd.append(optimum_share - observed_share)

                # Labels of the last usable matrix are reused for the output.
                column_name = optimum.columns
                index_name = optimum.index
            if (i % 500) == 0: print('Iteration:', i)

        if not avd:
            # Without this guard the labels above stay unset and the failure
            # surfaced later as an unrelated unbound-variable error.
            raise ValueError('no matrix of shape (3, 9) found in `optima`')

        #### Keep track of the distribution of the deviations
        act = {
            activity: [dev.rename(columns = nomenclature_dict)[activity].mean().mean() for dev in avd]
            for activity in ['Risk mitigation', 'Innovation capacity']
        }
        chl = {
            challenge: [dev.loc[challenge].mean() for dev in avd]
            for challenge in ['SDG water & energy', 'SDG cons. & prod.', 'SDG biodiversity']
        }
        act = pd.DataFrame.from_dict(act)
        chl = pd.DataFrame.from_dict(chl)

        opt = np.nanmean(opt, axis = 0)
        obs = np.nanmean(obs, axis = 0)
        avd = np.nanmean(avd, axis = 0)

        ## Original author's note: these averages are not the ones actually
        ## used; the statistics from get_pvals are.
        average_optima = pd.DataFrame(opt*100, columns = column_name, index = index_name)
        average_obs = pd.DataFrame(obs*100, columns = column_name, index = index_name)
        average_dev = pd.DataFrame(avd*100, columns = column_name, index = index_name)

        return average_optima.transpose(), average_obs.transpose(), average_dev.transpose(), osdg, oact, esdg, eact, act, chl
    
    def get_pvals(self, average_deviation, osdg, oact, esdg, eact):
        '''
        This is the function that generates the statistics for Figure 6

        For SDGs, actions and mechanisms (actions grouped through
        ``nomenclature_dict``) the mean optimum-minus-observed share, in
        percent, is rounded to one decimal and suffixed with the marker that
        ``ut.utils().significance`` returns for the p-value of a t-test
        between the optimum and observed rows (both rounded to two decimals).

        Parameters
        ----------
        average_deviation : accepted for interface compatibility; not used.
        osdg, esdg : per-matrix SDG shares (rows = SDGs) of the optimum and
            observed matrices.
        oact, eact : per-matrix action shares (rows = actions).

        Returns
        -------
        tuple
            ``(SDG_deviation, Action_deviation, mechanism_deviation,
            sdg_devDIST, act_devDIST, mec_devDIST)``: three two-column tables
            of annotated strings followed by the three frames of per-matrix
            deviations in percent.
        '''
        delta_label = r'$\Delta($optimum, observed$)$'

        def annotate(mean_dev, optimum, actual, key_label):
            # Build the strings in a fresh object Series: writing them into the
            # float Series element by element is an upcasting setitem, which
            # recent pandas versions deprecate.
            cells = []
            for key, value in zip(mean_dev.index, mean_dev.tolist()):
                # NOTE(review): `ttest_ind` is not imported by name in this
                # file; presumably it comes from the BasicLibraries star import.
                p_value = ttest_ind(optimum.loc[key].round(2), actual.loc[key].round(2))[1]
                cells.append(str(round(value, 1)) + ut.utils().significance(p_value))
            table = pd.Series(cells, index=mean_dev.index, dtype=object).reset_index()
            table.columns = [key_label, delta_label]
            return table

        #== SDG level
        sdg_devDIST = (osdg - esdg)*100
        SDG_deviation = annotate((osdg - esdg).mean(axis = 1)*100, osdg, esdg, 'SDG')

        #== Action level
        act_devDIST = (oact - eact)*100
        Action_deviation = annotate((oact - eact).mean(axis = 1)*100, oact, eact, 'Action')

        #== Mechanism level: relabel actions, then average within each label
        mean_by_mechanism = ((oact - eact).mean(axis = 1)*100).rename(index = nomenclature_dict).copy()
        mean_by_mechanism = mean_by_mechanism.groupby(mean_by_mechanism.index).mean()
        observed_by_mechanism = eact.rename(index = nomenclature_dict).copy()
        observed_by_mechanism = observed_by_mechanism.groupby(observed_by_mechanism.index).mean()
        optimum_by_mechanism = oact.rename(index = nomenclature_dict).copy()
        optimum_by_mechanism = optimum_by_mechanism.groupby(optimum_by_mechanism.index).mean()

        mec_devDIST = ((oact - eact)*100).rename(index = nomenclature_dict).copy()
        mec_devDIST = mec_devDIST.groupby(mec_devDIST.index).mean()
        mechanism_deviation = annotate(mean_by_mechanism, optimum_by_mechanism,
                                       observed_by_mechanism, 'Action')

        return SDG_deviation, Action_deviation, mechanism_deviation, sdg_devDIST, act_devDIST, mec_devDIST

    def make_difference_matrix(self, optima, observed):
        """Chain ``get_deviation_matrix_binary`` and ``get_pvals``.

        Returns ``(average_deviation, sdg_dev, act_dev, mec_dev, act, chl,
        sdg_devDIST, act_devDIST, mec_devDIST)``; the averaged optimum and
        observed matrices computed on the way are discarded.
        """
        (_, _, average_deviation,
         osdg, oact, esdg, eact,
         act, chl) = self.get_deviation_matrix_binary(optima, observed)

        (sdg_dev, act_dev, mec_dev,
         sdg_devDIST, act_devDIST, mec_devDIST) = self.get_pvals(
            average_deviation, osdg, oact, esdg, eact)

        return (average_deviation, sdg_dev, act_dev, mec_dev, act, chl,
                sdg_devDIST, act_devDIST, mec_devDIST)

    def plot_matrix(self, X, vmn=0, vmx=10, percentage=True):
        """Average a sequence of matrices and draw it with ``PlotMatrix``.

        Parameters
        ----------
        X : integer-indexable sequence of equally shaped DataFrames.
        vmn, vmx : colour-scale limits forwarded to ``PlotMatrix``.
        percentage : when true, each matrix is first rescaled so that its
            entries sum to 100.

        Returns
        -------
        DataFrame
            Whatever ``PlotMatrix`` returns for the transposed average.
        """
        count = len(X)
        members = [
            100*X[k]/(X[k].sum().sum()) if percentage else X[k]
            for k in range(count)
        ]
        averaged = np.mean(members, axis=0)

        # Row/column labels are taken from the last matrix of the sequence.
        template = X[count - 1]
        frame = pd.DataFrame(averaged, columns=template.columns, index=template.index)

        return self.PlotMatrix(frame.transpose(), rounded=2, dpi_=60, vmn=vmn, vmx=vmx)

    #========================================
    def get_deviation_matrix_by_firm(self, optima_pop, observed_firm):
        '''
        Compare population optima with one firm's observed matrices.

        optima_pop is the list of population optima; observed_firm is the
        list of firm-specific observed matrices. Each list is summed
        element-wise (rounded to one decimal) and divided by its grand total.
        Both results carry the row/column labels of optima_pop[0].

        Returns the optimum shares, the observed shares and their difference
        in percent, all three transposed.
        '''
        template = optima_pop[0]
        labels = {'columns': template.keys(), 'index': template.index}

        optimum_total = pd.DataFrame(np.sum(optima_pop, axis=0).round(1), **labels)
        observed_total = pd.DataFrame(np.sum(observed_firm, axis=0).round(1), **labels)

        optimum_share = optimum_total / optimum_total.sum().sum()
        observed_share = observed_total / observed_total.sum().sum()
        gap = (optimum_share - observed_share) * 100

        return optimum_share.transpose(), observed_share.transpose(), gap.transpose()
    

    
    def plot_matrix_by_firm(self, M, save_plot=False, FIGNAME='none'):
        """Draw ``M`` with ``PlotMatrix`` on a -5..5 colour scale.

        When ``save_plot`` is true the figure is saved as
        ``FIGURES + FIGNAME + '.pdf'``. Returns the frame produced by
        ``PlotMatrix``.
        """
        drawn = self.PlotMatrix(M, rounded=2, dpi_=60, vmn=-5, vmx=5)
        if save_plot:
            target = ''.join((FIGURES, FIGNAME, '.pdf'))
            plt.savefig(target, bbox_inches='tight', dpi=200)
        return drawn
    